/* C.Bag: Bag data type */

#include <stddef.h>
#include <stdlib.h>
#include <string.h>
#include "bag.h"

#ifdef test
#include <stdio.h>
#endif

/*
 * One node of a bag's singly linked list.  Each distinct element is stored
 * in exactly one node, together with the number of times it occurs.
 *
 * data is declared with a single byte, but nodes are allocated with
 * sizeof(struct link) - 1 + obj_size bytes so that data can hold a whole
 * element (the pre-C99 "struct hack").
 */
struct link
{
        struct link *next;      /* next distinct element, or NULL at the end */
        int count;              /* multiplicity of this element (>= 1) */
        char data[1];           /* element bytes; really obj_size bytes long */
};

/* Pointer-to-node type; note it deliberately shares its name with the tag */
typedef struct link *link;

/* Status values returned by the bag functions that can fail */

#define OK      1       /* operation succeeded */
#define ERR     0       /* size mismatch, element not present, or out of memory */

/*
 * find - locate the node holding element (compared bytewise over obj_size
 * bytes) in bag b.
 * Returns the node, or NULL if the element is absent.  When a node is found
 * and prev is non-NULL, *prev receives its predecessor (NULL if the node is
 * the head of the list); *prev is left untouched when nothing is found.
 */
static link find (const bag b, const void *element, link *prev)
{
        const int size = b->obj_size;
        link before = NULL;
        link node = b->first;

        /* Walk the list, remembering the predecessor of the node under test */
        while ( node != NULL && memcmp(element,node->data,size) != 0 )
        {
                before = node;
                node = node->next;
        }

        /* Only report a predecessor when there is a match to report */
        if ( node != NULL && prev != NULL )
                *prev = before;

        return node;
}

/* General component routines */

bag bag_new (int obj_len)
{
        register bag b;

        b = malloc(sizeof(struct bag));

        if ( b == NULL )
                return NULL;

        b->first    = NULL;
        b->obj_size = obj_len;

        return b;
}

/*
 * bag_free - release every node of b and then b itself.
 * Like free(), a NULL argument is accepted and ignored, so callers can
 * free unconditionally (e.g. after a failed bag_new).
 */
void bag_free (bag b)
{
        if ( b == NULL )
                return;

        bag_clear(b);
        free(b);
}

/*
 * bag_clear - remove every element from b, freeing all of its nodes.
 * The bag itself stays allocated and is left empty.
 */
void bag_clear (bag b)
{
        link doomed;

        /* Detach the head node, then release it, until nothing is left */
        while ( b->first != NULL )
        {
                doomed = b->first;
                b->first = doomed->next;
                free(doomed);
        }
}

/*
 * bag_copy - make b1 an exact copy of b2 (same elements, same counts, same
 * list order).  b1's previous contents are released first.
 *
 * Returns ERR, leaving b1 untouched, if the element sizes differ.  If memory
 * runs out part-way, b1 is emptied and ERR is returned.  Copying a bag onto
 * itself is a no-op that returns OK.
 */
int bag_copy (bag b1, const bag b2)
{
        link p;
        link node;
        link last = NULL;
        const int size = b2->obj_size;

        if ( b1->obj_size != size )
                return ERR;

        /* Clearing b1 would also empty b2, leaving nothing to copy */
        if ( b1 == b2 )
                return OK;

        bag_clear(b1);

        for ( p = b2->first; p != NULL; p = p->next )
        {
                node = malloc(sizeof(struct link) - 1 + size);
                if ( node == NULL )
                {
                        /* Every node linked so far is NULL-terminated, so
                           bag_clear can walk the partial list safely */
                        bag_clear(b1);
                        return ERR;
                }

                memcpy(node->data,p->data,size);
                node->count = p->count;
                node->next = NULL;

                /* Append, preserving b2's order */
                if ( last == NULL )
                        b1->first = node;
                else
                        last->next = node;
                last = node;
        }

        return OK;
}

/*
 * bag_equal - return 1 if b1 and b2 hold exactly the same elements with the
 * same counts, 0 otherwise (including when their element sizes differ).
 */
int bag_equal (const bag b1, const bag b2)
{
        link node;
        link match;
        int unique1 = 0;
        int unique2 = 0;

        if ( b1->obj_size != b2->obj_size )
                return 0;

        /* Every distinct element of b1 must occur in b2 with the same count */
        for ( node = b1->first; node != NULL; node = node->next, ++unique1 )
        {
                match = find(b2,node->data,NULL);
                if ( match == NULL || match->count != node->count )
                        return 0;
        }

        /* ...and b2 must not hold any distinct elements beyond those */
        for ( node = b2->first; node != NULL; node = node->next )
                ++unique2;

        return ( unique1 == unique2 );
}

/* bag_empty - return 1 if b holds no elements, 0 otherwise */
int bag_empty (const bag b)
{
        return !b->first;
}

/* bag_size - total number of elements in b, counting every repetition */
int bag_size (const bag b)
{
        int total = 0;
        link node = b->first;

        while ( node != NULL )
        {
                total += node->count;
                node = node->next;
        }

        return total;
}

/*
 * bag_iterate - call process(element, count) once per distinct element.
 * The callback returns 0 to continue, a positive value to stop normally,
 * or a negative value to stop with an error.
 * Returns 1 on normal completion (or a positive stop), 0 on a negative stop.
 */
int bag_iterate (const bag b, int (*process)(void *, int))
{
        link node;
        int status;

        for ( node = b->first; node != NULL; node = node->next )
        {
                status = process(node->data,node->count);

                /* Any non-zero status ends the walk; its sign is the verdict */
                if ( status != 0 )
                        return ( status > 0 );
        }

        /* Visited everything without being stopped */
        return 1;
}

/* bag-specific routines */

/*
 * bag_add - add one occurrence of *object (obj_size bytes) to b.
 * Returns OK, or ERR if a new node could not be allocated.
 */
int bag_add (bag b, const void *object)
{
        link node = find(b,object,NULL);

        /* Already present: just bump its multiplicity */
        if ( node != NULL )
        {
                node->count++;
                return OK;
        }

        /* New element: allocate a node large enough for obj_size data bytes */
        node = malloc(sizeof(struct link) - 1 + b->obj_size);
        if ( node == NULL )
                return ERR;

        node->count = 1;
        memcpy(node->data,object,b->obj_size);
        node->next = b->first;
        b->first = node;

        return OK;
}

/*
 * bag_remove - remove one occurrence of *object from b.
 * Returns OK, or ERR if the element is not in the bag.
 */
int bag_remove (bag b, const void *object)
{
        link before;
        link node = find(b,object,&before);

        if ( node == NULL )
                return ERR;

        /* Several copies present: drop one and keep the node */
        if ( node->count > 1 )
        {
                node->count--;
                return OK;
        }

        /* Last copy: splice the node out of the list, then release it */
        if ( before != NULL )
                before->next = node->next;
        else
                b->first = node->next;

        free(node);
        return OK;
}

int bag_member (const bag b, const void *object)
{
        return ( find(b,object,NULL) != NULL );
}

/* bag_count - number of occurrences of *object in b (0 if absent) */
int bag_count (const bag b, const void *object)
{
        link node = find(b,object,NULL);

        if ( node == NULL )
                return 0;

        return node->count;
}

/*
 * bag_union - store in b1 the additive union of b2 and b3: every element
 * occurs with its count in b2 plus its count in b3.
 *
 * b1 may be the same bag as b2 and/or b3, so bag_union(a,a,c) adds c to a.
 * Returns ERR if the element sizes differ or memory runs out; in the latter
 * case b1 holds a valid but incomplete result.
 */
int bag_union (bag b1, const bag b2, const bag b3)
{
        bag other;
        link p;
        link node;
        const int size = b1->obj_size;

        if ( size != b2->obj_size || size != b3->obj_size )
                return ERR;

        if ( b1 == b3 )
        {
                /* b1 already holds b3 and the sum is symmetric, so fold b2 in.
                   This also covers b1 == b2 == b3 (every count doubles). */
                other = b2;
        }
        else
        {
                /* Start from a copy of b2; nothing to copy when b1 is b2 */
                if ( b1 != b2 && !bag_copy(b1,b2) )
                        return ERR;
                other = b3;
        }

        for ( p = other->first; p != NULL; p = p->next )
        {
                node = find(b1,p->data,NULL);

                if ( node != NULL )
                {
                        /* When other == b1, node is p itself and its count doubles */
                        node->count += p->count;
                        continue;
                }

                node = malloc(sizeof(struct link) - 1 + size);

                if ( node == NULL )
                        return ERR;

                memcpy(node->data,p->data,size);
                node->count = p->count;         /* carry the full multiplicity */

                /* Reaching here implies other != b1 (otherwise find succeeds),
                   so prepending to b1 cannot disturb this traversal */
                node->next = b1->first;
                b1->first = node;
        }

        return OK;
}

/*
 * bag_intersection - store in b1 the elements present in both b2 and b3,
 * each with the smaller of its two counts.
 *
 * The result is built in a private list before b1 is touched, so b1 may be
 * the same bag as b2 and/or b3.  Returns ERR on element-size mismatch or if
 * memory runs out; b1 is unchanged in both cases.
 */
int bag_intersection (bag b1, const bag b2, const bag b3)
{
        link p;
        link q;
        link node;
        link result = NULL;
        const int size = b1->obj_size;

        if ( size != b2->obj_size || size != b3->obj_size )
                return ERR;

        for ( p = b2->first; p != NULL; p = p->next )
        {
                q = find(b3,p->data,NULL);

                if ( q == NULL )
                        continue;

                node = malloc(sizeof(struct link) - 1 + size);

                if ( node == NULL )
                {
                        /* Discard the partial result; b1 has not been modified */
                        while ( result != NULL )
                        {
                                node = result->next;
                                free(result);
                                result = node;
                        }
                        return ERR;
                }

                memcpy(node->data,p->data,size);
                node->count = ( p->count < q->count ? p->count : q->count );

                node->next = result;
                result = node;
        }

        /* Only now may b1's old contents go (b1 may be b2 or b3) */
        bag_clear(b1);
        b1->first = result;

        return OK;
}

/*
 * bag_difference - store in b1 the bag b2 with b3's occurrences removed:
 * each element keeps count(b2) - count(b3), and is dropped when that is
 * zero or less.
 *
 * The result is built in a private list before b1 is touched, so b1 may be
 * the same bag as b2 and/or b3.  Returns ERR on element-size mismatch or if
 * memory runs out; b1 is unchanged in both cases.
 */
int bag_difference (bag b1, const bag b2, const bag b3)
{
        link p;
        link q;
        link node;
        link result = NULL;
        const int size = b1->obj_size;

        if ( size != b2->obj_size || size != b3->obj_size )
                return ERR;

        for ( p = b2->first; p != NULL; p = p->next )
        {
                q = find(b3,p->data,NULL);

                /* b3 cancels every occurrence of this element */
                if ( q != NULL && p->count <= q->count )
                        continue;

                node = malloc(sizeof(struct link) - 1 + size);

                if ( node == NULL )
                {
                        /* Discard the partial result; b1 has not been modified */
                        while ( result != NULL )
                        {
                                node = result->next;
                                free(result);
                                result = node;
                        }
                        return ERR;
                }

                memcpy(node->data,p->data,size);
                node->count = ( q != NULL ? p->count - q->count : p->count );

                node->next = result;
                result = node;
        }

        /* Only now may b1's old contents go (b1 may be b2 or b3) */
        bag_clear(b1);
        b1->first = result;

        return OK;
}

/* bag_unique_count - number of distinct elements in b (repetitions ignored) */
int bag_unique_count (const bag b)
{
        int distinct = 0;
        link node = b->first;

        while ( node != NULL )
        {
                ++distinct;
                node = node->next;
        }

        return distinct;
}

/*
 * bag_subset - return 1 if b1 is contained in b2, i.e. every element of b1
 * occurs in b2 at least as many times; 0 otherwise (including when the
 * element sizes differ).
 */
int bag_subset (const bag b1, const bag b2)
{
        link node;
        link match;

        if ( b1->obj_size != b2->obj_size )
                return 0;

        for ( node = b1->first; node != NULL; node = node->next )
        {
                match = find(b2,node->data,NULL);

                /* Missing from b2, or present fewer times than in b1 */
                if ( match == NULL || match->count < node->count )
                        return 0;
        }

        return 1;
}

/*
 * bag_proper_subset - return 1 if b1 is contained in b2 (every element of b1
 * occurs in b2 at least as many times) and b2 holds strictly more elements
 * in total; 0 otherwise (including when the element sizes differ).
 */
int bag_proper_subset (const bag b1, const bag b2)
{
        link p;
        link q;
        int n1 = 0;

        if ( b1->obj_size != b2->obj_size )
                return 0;

        for ( p = b1->first; p != NULL; p = p->next )
        {
                q = find(b2,p->data,NULL);

                /* Containment must respect multiplicity: {x,x,x} is not
                   inside {x,y,y,y} even though b2 is larger */
                if ( q == NULL || p->count > q->count )
                        return 0;

                /* Total number of elements in b1 */
                n1 += p->count;
        }

        /* b1 is contained in b2, so they differ exactly when b2 is larger */
        return ( n1 < bag_size(b2) );
}

/*---------------------------------------------------------------------------*/

#ifdef test
/* Test callback: print the int at ptr once per occurrence, space-separated */
int print (void *ptr, int n)
{
        int i;

        for ( i = 0; i < n; ++i )
                printf("%d ",*(const int *)ptr);

        return STATUS_CONTINUE;
}

/* Test helper: print every element of an int bag on one line */
void bag_dump (bag b)
{
        fputs("bag: ",stdout);
        (void) bag_iterate(b,print);
        putchar('\n');
}
#endif

/*---------------------------------------------------------------------------*/

#ifdef test

#define BUFLEN 255
#define NBAGS  10       /* commands take a single-digit (%1d) bag index */

/*
 * Interactive test driver: reads one command per line ("help" lists them)
 * and applies it to the integer bags b[0]..b[9].  Ends on "quit" or at end
 * of input.
 */
int main (void)
{
        char buf[BUFLEN];
        int i, j, k, num;
        bag b[NBAGS];

        for ( i = 0; i < NBAGS; ++i )
        {
                b[i] = bag_new(sizeof(int));
                if ( b[i] == NULL )
                {
                        fprintf(stderr,"Cannot create bag %d\n",i);
                        while ( i-- > 0 )
                                bag_free(b[i]);
                        return 1;
                }
        }

        for ( ; ; )
        {
                printf(">");
                fflush(stdout);         /* prompt has no newline; force it out */

                /* End of input or a read error ends the session; otherwise the
                   loop would spin forever re-running the previous line */
                if ( fgets(buf,BUFLEN,stdin) == NULL )
                {
                        putchar('\n');
                        break;
                }

                if ( buf[0] == '\n' || buf[0] == '\0' )
                        continue;
                else if ( sscanf(buf,"clear %1d",&i) == 1 )
                        bag_clear(b[i]);
                else if ( sscanf(buf,"copy %1d %1d",&i,&j) == 2 )
                        bag_copy(b[i],b[j]);
                else if ( sscanf(buf,"equal %1d %1d",&i,&j) == 2 )
                        printf("%s\n",(bag_equal(b[i],b[j]) ? "yes" : "no"));
                else if ( sscanf(buf,"empty %1d",&i) == 1 )
                        printf("%s\n",(bag_empty(b[i]) ? "yes" : "no"));
                else if ( sscanf(buf,"size %1d",&i) == 1 )
                        printf("%d\n",bag_size(b[i]));
                else if ( sscanf(buf,"dump %1d",&i) == 1 )
                        bag_dump(b[i]);
                else if ( sscanf(buf,"add %1d %d",&i,&num) == 2 )
                        bag_add(b[i],&num);
                else if ( sscanf(buf,"remove %1d %d",&i,&num) == 2 )
                        bag_remove(b[i],&num);
                else if ( sscanf(buf,"member %1d %d",&i,&num) == 2 )
                        printf("%s\n", bag_member(b[i],&num) ? "yes" : "no");
                else if ( sscanf(buf,"count %1d %d",&i,&num) == 2 )
                        printf("%d\n", bag_count(b[i],&num));
                else if ( sscanf(buf,"union %1d %1d %1d",&i,&j,&k) == 3 )
                        bag_union(b[i],b[j],b[k]);
                else if ( sscanf(buf,"intersection %1d %1d %1d",&i,&j,&k) == 3 )
                        bag_intersection(b[i],b[j],b[k]);
                else if ( sscanf(buf,"difference %1d %1d %1d",&i,&j,&k) == 3 )
                        bag_difference(b[i],b[j],b[k]);
                else if ( sscanf(buf,"unique count %1d",&i) == 1 )
                        printf("%d\n", bag_unique_count(b[i]));
                else if ( sscanf(buf,"subset %1d %1d",&i,&j) == 2 )
                        printf("%s\n", bag_subset(b[i],b[j]) ? "yes" : "no");
                else if ( sscanf(buf,"proper subset %1d %1d",&i,&j) == 2 )
                        printf("%s\n", bag_proper_subset(b[i],b[j]) ? "yes" : "no");
                else if ( strncmp(buf,"help",4) == 0 )
                        printf(
                                "clear i\n"
                                "copy i j\n"
                                "equal i j\n"
                                "empty i\n"
                                "size i\n"
                                "dump i\n"
                                "add i n\n"
                                "remove i n\n"
                                "member i n\n"
                                "count i n\n"
                                "union i j k\n"
                                "intersection i j k\n"
                                "difference i j k\n"
                                "unique count i\n"
                                "subset i j\n"
                                "proper subset i j\n"
                              );
                else if ( strncmp(buf,"quit",4) == 0 )
                        break;
                else
                        printf("Mistake\n");
        }

        printf("Deleting b[0-9]\n");
        for ( i = 0; i < NBAGS; ++i )
                bag_free(b[i]);

        return 0;
}

#endif
